#include <layout.glsl>
#include "raytracing.glsl"
#include "mainScene.glsl"

#define BOUNCES 4

// Diffuse light arriving at a surface with the given normal.
//
// Adds up one clamped dot(normal, direction) term for each of the first three
// entries of `lights` (weighted by color.xyz and by the emission stored in
// color.w), followed by a single skybox term evaluated for `ray`.
//
// normal: surface normal at the shaded point.
// ray:    the ray that reached the surface; its reversed direction is used for
//         the skybox term.
// Returns the summed light as RGB.
vec3 diffuse(vec3 normal, Ray ray) {
    vec3 total = vec3(0.0);

    // NOTE(review): presumably lights[].direction points from the surface
    // towards the light and is already normalized - confirm against the scene.
    for (int lightIndex = 0; lightIndex < 3; lightIndex++) {
        vec4 colorAndEmission = lights[lightIndex].color;
        float lambert = clamp(dot(normal, lights[lightIndex].direction), 0.0, 1.0);
        total += lambert * colorAndEmission.w * colorAndEmission.xyz;
    }

    // Skybox contribution, weighted by how directly the surface faces back
    // along the incoming ray.
    float skyStrength = skyboxEmission(ray);
    vec3 skyTint = skyboxColor(ray);
    float skyFacing = clamp(dot(normal, -ray.direction), 0.0, 1.0);
    total += skyFacing * skyStrength * skyTint;

    return total;
}

////////////////////////////////////////////////
// Main tracing + marching loop
// Bounces, diffuse, emission, and reflections
//
// Follows `ray` through the scene for at most BOUNCES surface hits. Each hit
// records the blended material's emission, absorption (color.xyz), roughness
// (color.a) and the direct diffuse light at that point; the ray is then
// mirrored about the surface normal and marched again. The walk stops early
// when march() reports material id 0 (treated as a miss: the skybox is
// sampled), on the bounce after a fully rough surface, or once the marched
// distance reaches MAXDISTANCE.
//
// The recorded hits are then folded from the last one back to the first:
//   sampled = (sampled * (1 - roughness) + diffuse * roughness) * absorption
//             + emission
//
// ray: primary ray to trace (modified locally only).
// Returns the light carried back along the primary ray as RGB.
vec3 bounces(Ray ray) {
    vec3 absorptions[BOUNCES];
    vec3 emissions[BOUNCES];
    float roughnesses[BOUNCES];
    vec3 diffuses[BOUNCES];

    float hitDistance = 0.0;
    bool diffuseFinish = false;
    int i = 0;

    for (i = 0; i < BOUNCES && hitDistance < MAXDISTANCE; i++) {
        // march() packs: x = distance, y = primary material id,
        // z = secondary material id, w = blend factor between the two.
        vec4 result = march(ray);
        hitDistance = result.x;

        // The hit point must be recomputed from the current ray on every
        // bounce. Accumulating it across iterations (the previous `+=`)
        // added the earlier hit point to the new one, which corrupted the
        // normal and the next ray origin from the second bounce onwards.
        vec3 point = ray.origin + ray.direction * hitDistance;
        vec3 normal = computeNormal(point);

        int primaryMaterialId = int(result.y);
        int secondaryMaterialId = int(result.z);
        float materialBlend = result.w;

        if (primaryMaterialId == 0 || diffuseFinish) {
            // Terminal entry: only the skybox emission contributes.
            emissions[i] = skyboxEmission(ray) * skyboxColor(ray);
            absorptions[i] = vec3(0);
            roughnesses[i] = 0.0;
            diffuses[i] = vec3(0.0);
            // Count this entry before leaving, because `break` skips the
            // loop's own increment.
            i++;
            break;
        } else {
            // Material ids are 1-based; id 0 is handled by the branch above.
            vec4 primaryEmission = materials[primaryMaterialId - 1].emissive;
            vec4 secondaryEmission = materials[secondaryMaterialId - 1].emissive;
            vec4 emission = mix(primaryEmission, secondaryEmission, materialBlend);

            vec4 primaryAbsorption = materials[primaryMaterialId - 1].color;
            vec4 secondaryAbsorption = materials[secondaryMaterialId - 1].color;
            vec4 absorption = mix(primaryAbsorption, secondaryAbsorption, materialBlend);
            float roughness = absorption.a;

            emissions[i] = emission.xyz;
            absorptions[i] = absorption.xyz;
            roughnesses[i] = roughness;
            diffuses[i] = diffuse(normal, ray);

            // A fully rough surface has no mirror term, so the next
            // iteration is forced into the terminal branch.
            if (roughness >= 1 - EPSILON) {
                diffuseFinish = true;
            }

            // Step back slightly along the incoming direction before
            // reflecting so the next march does not start on the surface.
            ray.origin = point - 10 * EPSILON * ray.direction;
            ray.direction = reflect(ray.direction, normal);
        }
    }

    // Fold the recorded hits from the deepest bounce back to the first.
    vec3 sampled = vec3(0);
    for (; i > 0; i--) {
        int hit = i - 1;
        vec3 reflective = sampled * (1 - roughnesses[hit]);
        vec3 scattered = diffuses[hit] * roughnesses[hit];
        sampled = (reflective + scattered) * absorptions[hit] + emissions[hit];
    }

    return sampled;
}

// Picks a time between (time.current - window) and time.current, where the
// window is the smaller of shutter.deltaTime and shutter.exposureTime and the
// position inside it is chosen by random.x.
float randomizeShutterTime() {
    float window = min(shutter.deltaTime, shutter.exposureTime);
    float offsetIntoPast = mix(0.0, window, random.x);
    return time.current - offsetIntoPast;
}

// Maps a sample index to its texel in the sample history image.
//
// x/y come from the current fragment position; the layer is
// shutter.firstSample minus sampleIndex, wrapped into 0..255, so larger
// indices address slots further behind firstSample.
//
// sampleIndex: 0 addresses the slot at shutter.firstSample.
// Returns the integer image coordinates (x, y, layer).
ivec3 sampleCoordinates(int sampleIndex) {
    // Wrap with a bit mask instead of `%`: GLSL leaves `%` undefined when an
    // operand is negative, which the previous `(a - b + 256) % 256` hit as
    // soon as sampleIndex exceeded firstSample + 256. For a two's-complement
    // int, `& 255` is the non-negative remainder modulo 256 for any input.
    int layer = (int(shutter.firstSample) - sampleIndex) & 255;
    return ivec3(gl_FragCoord.xy, layer);
}

// Writes one sample into taaImage at the texel that sampleCoordinates()
// returns for sampleIndex. Alpha is stored as 1.0.
void storeSample(vec3 color, int sampleIndex) {
    vec4 texel = vec4(color, 1.0);
    imageStore(taaImage, sampleCoordinates(sampleIndex), texel);
}

// Averages this invocation's samples with older samples read from taaImage.
//
// light: sum of the samples already rendered by the caller; the divisor below
//        counts shutter.samplesPerFrame of them.
// Returns the mean colour with alpha 1.0.
vec4 integrateSamples(vec3 light) {
    // sampleCoordinates() wraps the layer index modulo 256, so the history
    // holds at most this many samples per pixel.
    const int historySlots = 256;

    int totalSamples = shutter.totalSamples;

    // Limit the history to the samples that fall inside the exposure window.
    // With a zero frame time the ratio is unbounded, so no extra limit
    // applies; skipping the division also avoids converting inf to int.
    if (shutter.deltaTime > 0.0) {
        // Clamp before the int conversion so a very long exposure cannot
        // overflow; anything above historySlots is cut by the ring limit
        // below anyway.
        float framesInExposure = min(shutter.exposureTime / shutter.deltaTime, float(historySlots));
        totalSamples = min(totalSamples, shutter.samplesPerFrame * int(framesInExposure));
    }

    // Never read further back than the ring holds: beyond this the index
    // wraps onto the slots written for the current frame, whose samples are
    // already part of `light` and would be counted twice.
    totalSamples = min(totalSamples, historySlots - shutter.samplesPerFrame);

    int count = 0;
    for (; count < totalSamples; count++) {
        // History starts right after the slots used by this frame's samples.
        int j = shutter.samplesPerFrame + count;
        ivec3 coordinates = sampleCoordinates(j);
        light += imageLoad(taaImage, coordinates).xyz;
    }

    // Guard the divisor so an invocation with no samples yields black
    // instead of NaN.
    int sampleCount = max(shutter.samplesPerFrame + count, 1);
    return vec4(light / float(sampleCount), 1.0);
}

// Fragment entry point: renders shutter.samplesPerFrame samples for this
// pixel, stores each one in the sample history, blends them with the stored
// history and applies a power curve to the result.
//
// Returns the final RGBA colour for the fragment.
vec4 mainFragment() {
    // Seed the random generator from the pixel position (0..1) and a value
    // derived from the current time.
    vec2 normalizedPixel = gl_FragCoord.xy / vec2(resolution.x, resolution.y);
    initRandom(normalizedPixel, int(time.current * 0.016666));

    Ray ray = Ray(inputRayOrigin, normalize(inputRayDirection));
    // Default shutter time, used as-is when no samples are rendered.
    shutterTime = time.current;

    vec3 light = vec3(0.0);
    for (int sampleIndex = 0; sampleIndex < shutter.samplesPerFrame; sampleIndex++) {
        // Each sample gets its own shutter time before the scene is traced.
        shutterTime = randomizeShutterTime();

        vec3 sampleLight = bounces(ray);
        storeSample(sampleLight, sampleIndex);
        light += sampleLight;

        shuffle();
    }

    vec4 result = integrateSamples(light);

    // NOTE(review): this raises the colour to the power 2.2; encoding linear
    // light for display normally uses 1.0 / 2.2 - confirm which is intended
    // for the render target.
    float gamma = 2.2;
    return vec4(pow(result.rgb, vec3(gamma)), result.a);
}
